feat(AMD): support AMD GPUs (ROCm/HIP) #354

Open
RixinLiu wants to merge 8 commits into main from amd-support-port

Collaborator

@RixinLiu RixinLiu commented Jun 3, 2026

Summary

Ports kvcached to AMD GPUs (ROCm/HIP). The KV cache, its C++ page allocator, and the GPU virtual-memory operations that grow/shrink it now build and run on both NVIDIA (CUDA driver API) and AMD (HIP), selected at compile time. Validated end-to-end on an AMD Instinct MI300X (ROCm 7.0): builds, serves dense models with output identical to the no-kvcached baseline, and the elastic KV cache physically grows and shrinks under load. NVIDIA is unaffected (CUDA non-regression checked).

Provenance

The AMD support was originally developed on the amd-support-init branch. That branch had diverged too far from main to merge directly, so its changes were re-applied as fresh, self-contained commits on this branch (amd-support-port, branched off the current main). Functionally this is the same work, reconstructed cleanly on top of main.

What changed

  • csrc/inc/gpu_vmm.hpp (new, ~255 lines) — a compile-time HIP/CUDA dispatch layer. Everything calls a vmm:: namespace (address_reserve, mem_create, mem_map, set_access, mem_unmap, …) that resolves to hipMem* or cuMem* based on KVCACHED_USE_HIP / KVCACHED_USE_CUDA.
  • cuda_utils.hpp → gpu_utils.hpp: renamed/generalized (logging + device helpers), CUDA-specific bits moved behind the abstraction.
  • csrc/ rewired: allocator, ftensor, page, page_allocator, torch_bindings, mem_info_tracker now go through gpu_vmm.hpp and at::/c10:: types instead of CUDA-only calls.
  • setup.py — auto-detects the backend from torch.version.hip vs torch.version.cuda; builds a CppExtension linking amdhip64 on ROCm, CUDAExtension linking cuda on NVIDIA.
  • Integration — vLLM/SGLang adapters accept hip device strings (PyTorch-ROCm masquerades GPUs as cuda, but the asserts are now backend-agnostic: cuda/hip).
  • SGLang 0.4.9–0.5.12 compat: version_utils falls back to package metadata when sglang.__version__ is absent (so the patches don't silently no-op), and the scheduler_memory_leak patch generalizes across SGLang's old/new leak-check layouts while leaving the req_to_token_pool check intact. (Also submitted standalone as the sglang-0512-compat PR.)
  • Layout default on ROCm — auto-default KVCACHED_CONTIGUOUS_LAYOUT=false on HIP (see note below); README updated.
  • benchmarks/bench_vmm — now cross-platform via the same abstraction (make for CUDA, make KVCACHED_BACKEND=hip for AMD).
  • tests/test_elastic_serving.py (new) — e2e elasticity check (grow/shrink under load), complements the manager-level test_kvcache_manager.py.
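The backend auto-detection described in the setup.py bullet can be sketched roughly as follows. This is a minimal illustration, not the PR's actual setup.py: `detect_backend` and `extension_plan` are hypothetical helpers, the version object is faked with `SimpleNamespace` so the snippet runs without PyTorch installed, and whether setup.py passes the `KVCACHED_USE_*` macros this way is an assumption.

```python
from types import SimpleNamespace


def detect_backend(torch_version) -> str:
    """Return "hip" for a ROCm build of PyTorch, "cuda" for a CUDA build.

    `torch_version` stands in for `torch.version`; on ROCm builds
    `torch.version.hip` is a version string and `torch.version.cuda` is None.
    """
    if getattr(torch_version, "hip", None):
        return "hip"
    if getattr(torch_version, "cuda", None):
        return "cuda"
    raise RuntimeError("PyTorch was built without CUDA or ROCm support")


def extension_plan(backend: str) -> dict:
    """Map a backend to the extension type and driver library named in the PR.

    The "macro" entry is an assumption about how the compile-time switch
    would be passed to the compiler; only the macro names come from the PR.
    """
    if backend == "hip":
        return {
            "extension": "CppExtension",
            "libraries": ["amdhip64"],
            "macro": "KVCACHED_USE_HIP",
        }
    return {
        "extension": "CUDAExtension",
        "libraries": ["cuda"],
        "macro": "KVCACHED_USE_CUDA",
    }


# Stand-ins for torch.version on the two validated machines.
rocm = SimpleNamespace(hip="7.0.51831-a3e329ad8", cuda=None)
nvidia = SimpleNamespace(hip=None, cuda="13.0")

rocm_plan = extension_plan(detect_backend(rocm))
nvidia_plan = extension_plan(detect_backend(nvidia))
```

In a real build script the argument would be `torch.version` itself, and the returned names would select between `torch.utils.cpp_extension.CppExtension` and `CUDAExtension`.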

Why non-contiguous layout is the ROCm default

kvcached supports two KV layouts. The historical default, contiguous, packs all layers into one interleaved tensor, so a per-layer view is strided (is_contiguous=False). vLLM's ROCm paged-attention backend slices the KV cache with .view() + paged kernels that assume a standard contiguous per-layer stride, so the strided views produce wrong output on AMD. The non-contiguous layout gives each layer its own standard contiguous tensor, which the ROCm kernels handle correctly. (On NVIDIA, FlashAttention/FlashInfer use stride-tolerant unbind/varlen paths, so contiguous works there.) We therefore auto-select non-contiguous on HIP; it remains overridable via KVCACHED_CONTIGUOUS_LAYOUT.
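A minimal sketch of the default-selection rule described above, assuming the override is read from the environment as a boolean-like string. The helper name `default_contiguous_layout` and the accepted spellings ("1", "true", "yes", "on") are illustrative assumptions; only the variable name KVCACHED_CONTIGUOUS_LAYOUT and the "explicit setting wins" behavior come from this PR.

```python
import os
from typing import Mapping, Optional


def default_contiguous_layout(
    is_hip_build: bool, env: Optional[Mapping[str, str]] = None
) -> bool:
    """Pick the KV layout: an explicit env setting wins, else depend on backend."""
    env = os.environ if env is None else env
    raw = env.get("KVCACHED_CONTIGUOUS_LAYOUT")
    if raw is not None:
        # Assumed parsing: treat common truthy spellings as True.
        return raw.strip().lower() in ("1", "true", "yes", "on")
    # No override: contiguous on CUDA, non-contiguous on HIP.
    return not is_hip_build


hip_default = default_contiguous_layout(True, {})
cuda_default = default_contiguous_layout(False, {})
hip_forced = default_contiguous_layout(True, {"KVCACHED_CONTIGUOUS_LAYOUT": "true"})
```

Passing `env` explicitly keeps the rule testable; in the library the check would presumably key off `torch.version.hip` rather than a boolean argument.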

Validation (AMD MI300X, ROCm 7.0)

| Area | Result |
| --- | --- |
| Builds on ROCm | setup.py HIP path compiles the whole csrc/ |
| VMM ops on HIP | bench_vmm (HIP) runs; mem_map p50 ~4 µs, mem_unmap ~82 µs |
| Correctness (dense) | ✅ MLA (DeepSeek-V2-Lite) and GPT-OSS output md5 == no-kvcached baseline |
| SGLang serving | ✅ all 7 SGLang patches apply; output md5 == baseline (SGLang 0.5.12.post1) |
| Elasticity | test_elastic_serving.py: mapped KV grows (hipMemMap) under load and shrinks (hipMemUnmap) on free; output unchanged |
| CUDA non-regression | ✅ unchanged on NVIDIA (DGX Spark) |

Environment

AMD — primary validation (MI300X):

| Component | Value |
| --- | --- |
| Host OS / kernel | Ubuntu 24.04.4 LTS · 6.8.0-106-generic · x86_64 |
| GPU | 1× AMD Instinct MI300X (PCI 0x74b5), 192 GiB HBM3 |
| ROCm | 7.0.0 |
| vLLM image | rocm/vllm:rocm7.0.0_vllm_0.11.2_20251210; Python 3.12.12, PyTorch 2.9.0a0+git1c57644 (torch.version.hip 7.0.51831-a3e329ad8), vLLM 0.11.2.dev673+g839868462 |
| SGLang image | lmsysorg/sglang:v0.5.12.post1-rocm700-mi30x; Python 3.10.12, PyTorch 2.9.0a0+git7bcbafe (torch.version.hip 7.0.51831-a3e329ad8), SGLang 0.5.12.post1 |

NVIDIA — CUDA non-regression (Spark):

| Component | Value |
| --- | --- |
| Machine | NVIDIA DGX Spark (GB10, aarch64, ~122 GiB unified memory) |
| GPU | NVIDIA GB10, compute capability sm_121 (12.1) |
| CUDA toolkit | 13.0 (/usr/local/cuda) |
| PyTorch | 2.10.0+cu130 (torch.version.hip = None) |
| vLLM | 0.19.2.dev0+gb1388b1fb.d20260422 (source build) |
| Attention backend | FLASH_ATTN (FlashAttention v2) |

Known limitations

  • Hybrid / Mamba models (e.g. Falcon-H1) fail on ROCm with a Triton kernel error (arange's range must be a power of 2) that reproduces on the unmodified engine too — an engine/Triton-on-ROCm issue, not a kvcached regression.
  • TP/PP > 1 not tested — validation was on a single MI300X.
  • Cross-engine "giveback under pressure" (one engine releasing KV so another can grow) is not independently validated; single-engine grow/shrink is.

Scope & next steps

This PR establishes portability and correctness on AMD (see the validation table) — it deliberately does not include performance benchmarking. AMD benchmarking is the next step and is tracked separately on the amd-benchmark branch (serving overhead vs vanilla, multi-instance elastic sharing, and VMM-op latency), kept out of this PR to keep it focused on the port itself.

Note

The SGLang-compat fixes (version detection + leak-check) are committed in this branch, so it serves SGLang on AMD standalone — no external dependency. The same fixes are also submitted as the focused sglang-0512-compat PR#353; the two carry identical changes to those files, so whichever merges second simply re-applies the same lines.

Commits

  • feat: support AMD GPUs (ROCm/HIP) via gpu_vmm abstraction
  • feat(amd): accept hip device strings in sglang/vllm integration
  • fix(amd): default to non-contiguous KV layout on ROCm
  • test(amd): add e2e KV-cache elasticity-under-load test
  • fix(sglang): version detection + refactored leak check (SGLang 0.5.11+)

RixinLiu and others added 5 commits June 2, 2026 14:45
Route all GPU virtual-memory calls through a new compile-time HIP/CUDA
abstraction so the same code builds and runs on both NVIDIA (CUDA driver API)
and AMD (HIP runtime).

- csrc/inc/gpu_vmm.hpp: new backend-neutral VMM wrappers, dispatched by
  KVCACHED_USE_HIP / KVCACHED_USE_CUDA; adds mem_get_info + device_synchronize.
- cuda_utils.hpp -> gpu_utils.hpp: check macros route to gpu_vmm::check; keeps
  the LOGGER stack.
- page/ftensor/allocator/page_allocator/torch_bindings: use gpu_vmm and lower
  torch:: types to c10::/at:: (drop the torch/extension.h umbrella).
- setup.py: detect torch.version.hip; build CppExtension(+amdhip64) for ROCm,
  CUDAExtension(+cuda) for NVIDIA.
- bench_vmm: build for either backend (make KVCACHED_BACKEND=hip).

Verified on AMD Instinct MI300X (ROCm 7.2): bench_vmm, full extension build,
and a python smoke test (init -> create -> map -> GPU r/w -> unmap) all pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

Make the elastic pools/allocators accept both 'cuda' and 'hip' device strings
via _is_supported_gpu_device (no functional change on ROCm, which reports
'cuda'). Also sync the bench_vmm README for the HIP backend and drop the
orphaned bench cuda_utils.hpp.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

vLLM's ROCm attention backend (split_kv_cache + paged kernels) cannot read the
strided per-layer KV tensors that kvcached's contiguous (compound-page) layout
produces; CUDA's FlashAttention/FlashInfer tolerate it. Auto-default
CONTIGUOUS_LAYOUT=false when torch is a HIP build (explicit env still wins) so
AMD is correct out of the box, plus a README note.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
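To make the "strided per-layer tensor" point in the commit above concrete, here is a small stride calculation in plain Python. It assumes, purely for illustration, that the contiguous layout interleaves layers as [num_blocks, num_layers, block_elems]; the real kvcached layout may differ, and the helper names are hypothetical. The check mirrors the usual row-major contiguity rule for dimensions larger than 1.

```python
def contiguous_strides(shape):
    """Row-major strides (in elements) of a densely packed tensor."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def select(shape, strides, axis):
    """Shape and strides of the view that fixes one index along `axis`."""
    return shape[:axis] + shape[axis + 1:], strides[:axis] + strides[axis + 1:]


def is_contiguous(shape, strides):
    """True when the strides match a densely packed row-major tensor."""
    return tuple(strides) == contiguous_strides(shape)


# Assumed interleaved buffer: 4 blocks x 3 layers x 8 elements per block.
interleaved_shape = (4, 3, 8)
interleaved_strides = contiguous_strides(interleaved_shape)

# Per-layer view of the interleaved buffer: consecutive blocks of one layer
# are separated by the other layers' data, so the block stride is 24, not 8.
layer_view = select(interleaved_shape, interleaved_strides, axis=1)

# Non-contiguous layout: each layer owns its own densely packed tensor.
own_tensor = ((4, 8), contiguous_strides((4, 8)))
```

Under these assumptions `layer_view` has the same shape as `own_tensor` but different strides, which is the situation a kernel expecting standard per-layer strides would misread.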
Drives the vLLM offline engine with kvcached and watches the mapped KV
footprint via the /dev/shm IPC: it grows on load (mem_map) and shrinks on
free (mem_unmap) with output unchanged. Complements the manager-level
test_kvcache_manager.py. Validated on AMD MI300X (ROCm 7.0); runs on
NVIDIA too (device "cuda:0").

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

…0.5.12)

Bundles the SGLang-compat fixes (also submitted standalone as the
sglang-0512-compat PR) so SGLang-on-AMD works from this branch alone:

- version_utils: fall back to importlib.metadata when sglang exposes no
  module-level __version__ (source builds), so the patches don't silently
  no-op.
- patches: generalize the scheduler_memory_leak patch across SGLang's
  leak-check layouts (old single method / new SchedulerRuntimeCheckerMixin),
  skipping the req_to_token_pool-specific check kvcached must not silence.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
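The version fallback described in the commit above can be sketched like this. It is an illustration under assumptions, not the code in version_utils: `detect_version` is a hypothetical name, and it assumes the distribution name matches the module name unless one is given.

```python
import importlib
from importlib import metadata
from typing import Optional


def detect_version(module_name: str, dist_name: Optional[str] = None) -> Optional[str]:
    """Prefer the module-level __version__, else fall back to package metadata."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        module = None

    version = getattr(module, "__version__", None)
    if version:
        return str(version)

    # Source builds may expose no __version__; installed metadata usually
    # still records the version.
    try:
        return metadata.version(dist_name or module_name)
    except metadata.PackageNotFoundError:
        return None
```

Returning None instead of raising lets the caller decide whether a missing version should skip the patches or fail loudly.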
Copilot AI review requested due to automatic review settings June 3, 2026 13:29

samples needed an element type; seg_name=[None] inferred as list[None],
which made seg_name[0] non-indexable under mypy. Verified against the CI
mypy-3.10 hook (and ruff/isort/codespell) locally.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
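The typing issue in the commit above can be illustrated with a short example. The variable names `samples` and `seg_name` come from the commit message; the element types, the `record` helper, and the values are assumptions made only for this sketch.

```python
from typing import List, Optional

# Without annotations, mypy cannot infer an element type for `samples`,
# and infers List[None] for `seg_name`, so indexing into seg_name[0] is rejected.
samples: List[int] = []
seg_name: List[Optional[str]] = [None]  # one-slot holder, filled in later


def record(name: str, mapped_bytes: int) -> None:
    """Store the segment name and one footprint sample (hypothetical helper)."""
    seg_name[0] = name
    samples.append(mapped_bytes)


record("kvcached_seg", 4096)

name = seg_name[0]
assert name is not None  # narrows Optional[str] to str for the type checker
prefix = name[:8]
```

The `typing.List` and `typing.Optional` spellings are used here because they also type-check under older Python targets.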
Contributor

Copilot AI left a comment


Pull request overview

Adds a compile-time GPU backend abstraction so kvcached can run on AMD GPUs via ROCm/HIP while preserving the existing CUDA path, and updates integrations/docs/tests accordingly.

Changes:

  • Introduces a HIP/CUDA VMM dispatch layer (gpu_vmm.hpp) and rewires core C++ components to use it.
  • Updates build tooling (setup.py, benchmark Makefile) to select/link the correct backend (CUDA vs ROCm/HIP).
  • Extends integrations (vLLM/SGLang) and adds an end-to-end elasticity-under-load script/test plus documentation for ROCm’s default KV layout.

Reviewed changes

Copilot reviewed 26 out of 26 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| tests/test_elastic_serving.py | Adds an end-to-end script/test to observe KV mapping growth/shrink under load via IPC stats. |
| setup.py | Detects CUDA vs ROCm from PyTorch and builds/links with the appropriate extension type and libraries. |
| README.md | Documents ROCm defaulting to non-contiguous KV layout and rationale. |
| kvcached/utils.py | Defaults KV layout to non-contiguous on ROCm unless explicitly overridden. |
| kvcached/integration/vllm/interfaces.py | Adjusts GPU availability assertion message (integration path touched). |
| kvcached/integration/version_utils.py | Adds metadata fallback for version detection when module attributes are missing. |
| kvcached/integration/sglang/patches.py | Broadens device-string acceptance and generalizes scheduler leak-check patching. |
| kvcached/integration/sglang/interfaces.py | Adjusts GPU availability assertion message (integration path touched). |
| csrc/torch_bindings.cpp | Switches to at::Tensor and updates Torch/pybind includes for bindings. |
| csrc/page.cpp | Ports GPU page allocation/mapping to the new backend-agnostic VMM layer. |
| csrc/page_allocator.cpp | Uses backend-agnostic mem-info and sync calls during unmap/availability computations. |
| csrc/inc/torch_utils.hpp | Changes dtype helpers to c10::ScalarType declarations and header includes. |
| csrc/inc/page.hpp | Replaces CUDA-specific types with backend-agnostic VMM handle/access helpers. |
| csrc/inc/page_allocator.hpp | Removes CUDA-specific includes and adjusts headers for new usage. |
| csrc/inc/mem_info_tracker.hpp | Updates to use generalized GPU utilities header. |
| csrc/inc/impl/torch_utils.ipp | Converts dtype mapping helpers to c10::ScalarType. |
| csrc/inc/gpu_vmm.hpp | New HIP/CUDA abstraction for VMM operations and error handling. |
| csrc/inc/gpu_utils.hpp | Generalizes logging/check macros to route through the new VMM abstraction. |
| csrc/inc/ftensor.hpp | Replaces torch::* types with at::Tensor / c10::* types. |
| csrc/inc/allocator.hpp | Replaces torch::* types with at::Tensor / c10::* types and renames init helper. |
| csrc/ftensor.cpp | Ports virtual address reservation/mapping/unmapping to backend-agnostic VMM calls. |
| csrc/allocator.cpp | Ports allocator init and tensor creation paths to backend-agnostic VMM calls/types. |
| benchmarks/bench_vmm/README.md | Updates benchmark docs to reflect CUDA+HIP support and renamed operations. |
| benchmarks/bench_vmm/Makefile | Adds backend selection (cuda vs hip) and links the correct driver library. |
| benchmarks/bench_vmm/cuda_utils.hpp | Removes CUDA-only utilities now superseded by shared GPU utilities. |
| benchmarks/bench_vmm/bench_vmm.cpp | Ports benchmark implementation to the backend-agnostic VMM utilities. |
Comments suppressed due to low confidence (2)

kvcached/integration/vllm/interfaces.py:200

  • device is now allowed to be a HIP-style string (e.g. "hip:0"), but this function still passes it directly to torch.cuda.get_device_properties(...) and to create_kv_tensors(...). PyTorch does not recognize a hip device type (ROCm GPUs are exposed as cuda), so this will raise and/or break the C++ extension device parsing when device starts with "hip".
    assert torch.cuda.is_available(), "GPU backend is not available via torch.cuda."

    # --- Compute per-layer memory budget and number of blocks ---
    gpu_mem_bytes = torch.cuda.get_device_properties(device).total_memory
    gpu_mem_bytes_per_layer_k_or_v = gpu_mem_bytes // num_layers // num_k_or_v

kvcached/integration/sglang/interfaces.py:100

  • This function now accepts HIP-style device strings ("hip:0"), but still passes device directly to torch.cuda.get_device_properties(...) and create_kv_tensors(...). PyTorch ROCm builds expose HIP devices as cuda, so a hip:* device string will not be understood by either PyTorch or the C++ extension’s c10::Device parsing.
    assert torch.cuda.is_available(), "GPU backend is not available via torch.cuda."

    # SGLang named it "page" to be consistent with PagedAttention. But we call
    # it "block" to distinguish a KV cache block and a physical memory page.
    block_size = page_size

Comment thread csrc/page_allocator.cpp
Comment thread csrc/inc/torch_utils.hpp

- page_allocator: check mem_get_info status via CHECK_GPU instead of
  discarding it with (void); a failed call previously computed a page
  count from uninitialized sizes. Zero-init the sizes too.
- torch_utils.hpp: include <pybind11/pybind11.h> so the header is
  self-contained -- it declares functions taking py::object but relied
  on the includer pulling in pybind11 first.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 26 out of 26 changed files in this pull request and generated 6 comments.

Comment thread kvcached/utils.py
Comment thread kvcached/integration/sglang/interfaces.py
Comment thread kvcached/integration/vllm/interfaces.py
Comment thread tests/test_elastic_serving.py
Comment thread tests/test_elastic_serving.py
Comment thread csrc/torch_bindings.cpp

The integration accepts `hip` device strings, but PyTorch-ROCm and the C++
extension (c10::Device) address AMD GPUs as `cuda`, so a literal `hip:0`
would fail in torch.cuda.get_device_properties / create_kv_tensors / the
C++ init. Add a shared normalize_gpu_device() helper and apply it at every
device entry point in both the vLLM and SGLang integrations (init_kvcached
+ alloc paths). No-op for the `cuda` strings the engines actually pass on
ROCm; verified with a vLLM generate smoke on MI300X.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
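The normalization described in the commit above might look roughly like the following. The helper name `normalize_gpu_device` comes from the commit message, but this signature and body are a guess at its behavior (string inputs only), not the code in the PR.

```python
def normalize_gpu_device(device: str) -> str:
    """Rewrite a `hip` or `hip:N` device string to the `cuda` spelling.

    PyTorch-ROCm and c10::Device address AMD GPUs as `cuda`, so any other
    device string is returned unchanged.
    """
    kind, sep, index = device.partition(":")
    if kind.lower() == "hip":
        return "cuda" + sep + index
    return device


normalized = [normalize_gpu_device(d) for d in ("hip:0", "hip", "cuda:1", "cpu")]
```

Applying one helper at every device entry point keeps the downstream `torch.cuda.*` calls and the C++ init free of backend-specific string handling.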
@jiarong0907 jiarong0907 requested a review from cui36 June 3, 2026 16:06
Collaborator

cui36 commented Jun 4, 2026

Reproduced the experiments on AMD. LGTM.


3 participants